import yaml
import argparse

# ===========================================================================
# LOAD CONFIGURATIONS
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` applies ``type`` to the raw string, and ``bool("False")`` is
    ``True`` because any non-empty string is truthy.  This converter parses
    the text instead.

    Args:
        value: The raw token (or an already-converted ``bool``).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised
            boolean spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string is kept as False to match the old ``bool("")`` result.
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


def get_params(argv=None):
    """Build the experiment parameters from the command line and a YAML file.

    The command line is parsed first.  Then, unless ``--cfg`` is ``None`` or
    the literal string ``'None'``, the YAML file it points to is loaded and
    every top-level key is written onto the namespace.  Values from the YAML
    file therefore take precedence over command-line values, and keys that
    are not declared below are added as new attributes.

    Boolean options accept ``true/false``, ``yes/no``, ``1/0``, ``on/off``
    (case-insensitive); giving the flag with no value means ``True``.

    Args:
        argv: Optional list of argument strings.  ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: The merged parameters.

    Raises:
        ValueError: If the YAML file's top level is not a mapping.
    """
    parser = argparse.ArgumentParser(description="Continual Learning for NER")

    # Shared keyword arguments for every on/off switch.
    bool_flag = dict(default=False, type=_str2bool, nargs="?", const=True)

    # =========================================================================================
    # Experimental Settings
    # =========================================================================================
    # Debug
    parser.add_argument("--is_debug", help="if skipping the test on training and validation set", **bool_flag)

    # Wandb
    parser.add_argument("--is_wandb", help="if using wandb", **bool_flag)
    parser.add_argument("--wandb_project", type=str, default=None, help="wandb project name")
    parser.add_argument("--wandb_entity", type=str, default=None, help="wandb entity name")

    # Config
    parser.add_argument("--cfg", default="./config/default.yaml", help="Hyper-parameters") # debug: "./config/default.yaml"

    # Path
    parser.add_argument("--wandb_name", type=str, default="default", help="Experiment name")
    parser.add_argument("--logger_filename", type=str, default="train.log")
    parser.add_argument("--dump_path", type=str, default="experiments", help="Experiment saved root path")
    parser.add_argument("--seed", type=int, default=None, help="Random Seed")

    # Model
    parser.add_argument("--backbone", type=str, default="bert-base-cased", choices=['resnet18','bert-base-cased','vit_base_patch16_224', 'vit_base_patch16_224_in21k', 'vit_tiny_patch16_224_in21k','deit_small_patch16_224'], help="backbone name")
    parser.add_argument("--use_adapter", help="If using adapter for PTMs", **bool_flag)
    parser.add_argument("--save_ckpt", help="If save ckpt", **bool_flag)
    parser.add_argument("--dropout", type=float, default=0, help="dropout rate")
    parser.add_argument("--hidden_dim", type=int, default=768, help="Hidden layer dimension")
    parser.add_argument("--alpha", type=float, default=0, help="Trade-off parameter")

    # Data
    parser.add_argument("--task_name", type=str, default='IC', choices=['NER','TC','IC'], help="the task to perform continual learning")
    parser.add_argument("--dataset", type=str, default="CIFAR100", choices=['CIFAR100','five_datasets','tiny-imagenet-200','omnibenchmark','ImageNetA','ImageNetR','vtab','objectnet'], help="dataset name")
    parser.add_argument("--class_ft", type=int, default=5, help="(In one dataset mode) Number of classes in the first task")
    parser.add_argument("--class_pt", type=int, default=5, help="(In one dataset mode) Number of classes in each task")
    parser.add_argument("--n_samples", type=int, default=-1, help="conduct few-shot learning (10, 25, 40, 55, 70, 85, 100)")
    parser.add_argument("--schema", type=str, default="BIO", choices=['IO','BIO','BIOES'], help="Lable schema")

    # =========================================================================================
    # Training Settings
    # =========================================================================================
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--max_seq_length", type=int, default=128, help="Max length for each sentence")

    parser.add_argument("--lr", type=float, default=3e-5, help="Initial learning rate")
    parser.add_argument('--final_fc_lr', default=1e-3, type=float, help='the learning rate for the final FC layer')
    parser.add_argument("--training_epochs", type=int, default=5, help="Number of training epochs")
    parser.add_argument("--first_training_epochs", type=int, default=5, help="Number of training epochs in first task_id (will be set as training_epochs by default)")

    parser.add_argument("--mu", type=float, default=0.9, help="Momentum")
    parser.add_argument("--weight_decay", type=float, default=5e-4, help="Weight decay")
    parser.add_argument("--scheduler", type=str, default='constant', choices=['constant','multistep'], help="scheduler")
    parser.add_argument("--scheduler_milestone", type=str, default='[50,90]', help="the milestone for multistep scheduler")

    parser.add_argument("--info_per_epochs", type=int, default=1, help="Print information every how many epochs")
    parser.add_argument("--info_per_steps", type=int, default=0, help="Print information every how many steps")
    parser.add_argument("--save_per_epochs", type=int, default=0, help="Save checkpoints every how many epochs")
    parser.add_argument("--save_per_steps", type=int, default=0, help="Save checkpoints every how many steps")
    parser.add_argument("--evaluate_interval", type=int, default=1, help="Evaluation interval")
    parser.add_argument("--is_use_last_ckpt", help="If using last epoch model instead of the selected model according to dev acc", **bool_flag)
    parser.add_argument("--early_stop", type=int, default=100, help="No improvement after several epoch, we stop training")

    # =========================================================================================
    # General Settings for Incremental Learning
    # =========================================================================================
    parser.add_argument("--is_rand_init", help="If loading randomly initialized model", **bool_flag)
    parser.add_argument("--is_imprint", help="If rescale the new weight matrix", **bool_flag)
    parser.add_argument("--is_fix_enc", help="If fixing encoder", **bool_flag)
    parser.add_argument("--is_fix_old_cls", help="If fix the old classifier", **bool_flag)
    parser.add_argument("--is_probing", help="If probing the model for each epoch", **bool_flag)
    parser.add_argument("--probing_interval", default=1, type=int, help="Probing every n epoch")
    parser.add_argument("--is_tracking", help="If tracking the moving distance in the hidden feature space", **bool_flag)
    parser.add_argument("--tracking_interval", default=1, type=int, help="Tracking every n epoch")
    parser.add_argument("--reserved_ratio", type=float, default=0, help="the ratio of reserved samples")

    # =========================================================================================
    # Baseline Settings
    # =========================================================================================

    parser.add_argument("--is_buffer", help="If using buffer", **bool_flag)
    parser.add_argument("--is_fix_budget_each_class", help="If using fix budget for each class (buffer_size//all_classes)", **bool_flag)
    # NOTE: the 'reservior' spelling is part of the CLI contract; do not "fix" it here.
    parser.add_argument("--sampling_alg", default='herding', choices=['herding','reservior'], type=str, help="The sampling algorithm")
    parser.add_argument("--buffer_size", type=int, default=0, help="The buffer size")

    # Multi-task
    parser.add_argument("--is_multitask", help="If using multi-task-learning", **bool_flag)

    # Experience Replay
    parser.add_argument("--is_er", help="If using experience replay", **bool_flag)
    parser.add_argument("--is_mix_er", help="If combine buffer data with new data (each batch)", **bool_flag)
    parser.add_argument("--is_combine_er", help="If combine buffer data with new data", **bool_flag)
    parser.add_argument("--replay_interval", type=int, default=1, help="The interval for experience replay")

    # MBPA&MBPA++
    parser.add_argument("--is_mbpa", help="If using test time adaptation", **bool_flag)
    parser.add_argument("--mbpa_K", default=32, type=int, help="The number of KNN in MBPA")
    parser.add_argument("--mbpa_step", default=30, type=int, help="The number of update step in MBPA")
    # These two are fractional quantities; type=int would reject values such as "1e-4".
    parser.add_argument("--mbpa_lambda", default=1e-3, type=float, help="The weight of the L2 loss in MBPA")
    parser.add_argument("--mbpa_lr", default=1e-5, type=float, help="The learning rate of MBPA")

    # LwF
    parser.add_argument("--is_LWF", help="If using LwF (Learning without Forgetting)", **bool_flag)
    parser.add_argument("--LWF_temperature", default=2, type=float, help="The temperature in LwF")
    parser.add_argument("--LWF_lambda", default=1, type=float, help="The weight of distill loss in LwF")

    # EWC
    parser.add_argument("--is_EWC", help="If using EWC (Elastic Weight Consolidation)", **bool_flag)
    parser.add_argument("--EWC_lambda", default=5000, type=float, help="The weight of distill loss in EWC")

    # Distillate&Self-Training (add temperature or not)
    parser.add_argument("--is_distill", help="If using distillation model for baseline", **bool_flag)
    parser.add_argument("--distill_weight", type=float, default=2, help="distillation weight for loss")
    parser.add_argument("--adaptive_distill_weight", help="If using adaptive weight", **bool_flag)
    parser.add_argument("--adaptive_schedule", type=str, default='root', choices=['root','linear','square'], help="The schedule for adaptive weight")
    parser.add_argument("--temperature", type=float, default=1, help="temperature of the student model")
    parser.add_argument("--ref_temperature", type=float, default=1, help="temperature of the teacher model")
    parser.add_argument("--is_ranking_loss", help="Add ranking loss in LUCIR", **bool_flag)
    parser.add_argument("--ranking_weight", type=float, default=5, help="weight for ranking loss")

    # LUCIR
    parser.add_argument("--is_lucir", help="If using LUCIR as baseline", **bool_flag)
    parser.add_argument("--lucir_lw_distill", type=float, default=50, help="Loss weight for distillation")
    parser.add_argument("--lucir_K", type=int, default=1, help="Top K for MR loss")
    parser.add_argument("--lucir_mr_dist", type=float, default=0.5, help="Margin for MR loss")
    parser.add_argument("--lucir_lw_mr", type=float, default=1, help="Loss weight for MR loss")

    # PodNet
    parser.add_argument("--is_podnet", help="If using Podnet as baseline", **bool_flag)
    parser.add_argument("--podnet_is_nca", help="If using NCA loss", **bool_flag)
    parser.add_argument("--podnet_nca_scale", type=float, default=1, help="The scaling factor for NCA")
    parser.add_argument("--podnet_nca_margin", type=float, default=0.6, help="The margin for NCA")
    parser.add_argument("--podnet_lw_pod_flat", type=float, default=1, help="Loss weight for flatten (last) feature distillation loss")
    parser.add_argument("--podnet_lw_pod_spat", type=float, default=1, help="Loss weight for intermediate feature distillation loss")
    parser.add_argument("--podnet_normalize", help="If normalize the feature before calculating the distance", **bool_flag)

    # BaCE
    parser.add_argument("--is_BaCE", help="using exponiential moving average to compute joint loss", **bool_flag)
    parser.add_argument("--BaCE_prompt_tuning", help="If using prompt tuning", **bool_flag)
    parser.add_argument("--BaCE_prompt_len", type=int, default=5, help="The length of the prompt")
    parser.add_argument("--BaCE_prompt_lr", type=float, default=0.03, help="The learning rate of prompt")
    parser.add_argument("--BaCE_k", type=int, default=5, help="Number of KNN")
    parser.add_argument("--BaCE_beta", type=float, default=1.0, help="The decay for BaCE")
    parser.add_argument("--BaCE_W0_bg", type=float, default=0.95, help="The init weight of anchor for BaCE")
    parser.add_argument("--BaCE_W0", type=float, default=0.95, help="The weight of anchor for BaCE")
    parser.add_argument("--BaCE_dist_threshold", type=float, default=1e8, help="The distance threshold for selecting KNN")
    parser.add_argument("--BaCE_lambda", type=float, default=1.0, help="The weight of distillation loss")
    parser.add_argument("--BaCE_lambda_type", type=str, default='sqrt', choices=['constant','linear','sqrt'], help="How the weight is computed in different CIL steps")
    parser.add_argument("--BaCE_weight_type", type=str, default='average', choices=['average','distance'], help="How the KNN weight is computed")
    parser.add_argument("--BaCE_distill_new", help="If distilling the samples from new classes", **bool_flag)
    parser.add_argument("--BaCE_distill_new_only", help="If only distilling the samples from new classes", **bool_flag)
    parser.add_argument("--BaCE_distill_joint", help="If distilling both scores and logits", **bool_flag)
    parser.add_argument("--BaCE_is_plot", help="If plot KNNs", **bool_flag)

    # CLSER
    parser.add_argument("--is_CLSER", help="If using CLSER", **bool_flag)
    parser.add_argument("--CLSER_alpha_1", default=0.999, type=float, help="The momentum of stable model")
    parser.add_argument("--CLSER_alpha_2", default=0.999, type=float, help="The momentum of plastic model")
    parser.add_argument("--CLSER_freq_1", default=0.7, type=float, help="The update frequency of stable model")
    parser.add_argument("--CLSER_freq_2", default=0.9, type=float, help="The update frequency of plastic model")
    parser.add_argument("--CLSER_lambda", default=0.1, type=float, help="The hyper-paramerter of the MSE term")

    # MEMO
    parser.add_argument("--is_MEMO", help="If using MEMO", **bool_flag)

    params = parser.parse_args(argv)

    if params.cfg is not None and params.cfg != 'None':
        with open(params.cfg, encoding="utf-8") as f:
            # An empty YAML document loads as None; treat it as "no overrides".
            config = yaml.safe_load(f) or {}
        if not isinstance(config, dict):
            raise ValueError("config file %r must contain a top-level mapping" % (params.cfg,))
        # Values from the YAML file overwrite whatever was parsed above,
        # including options given explicitly on the command line.
        for key, value in config.items():
            setattr(params, key, value)

    return params